{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Custom Keras Layer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Idea:\n",
    "\n",
    "We build a custom activation layer called **Antirectifier**,\n",
    "which modifies the shape of the tensor that passes through it.\n",
    "\n",
    "We need to specify two methods: `get_output_shape_for` and `call`.\n",
    "\n",
    "Note that the same result can also be achieved via a `Lambda` layer (`keras.layer.core.Lambda`)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```python\n",
    "\n",
    "keras.layers.core.Lambda(function, output_shape=None, arguments=None)\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because our custom layer is written with primitives from the Keras backend (`K`), our code can run both on TensorFlow and Theano."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.models import Sequential\n",
    "from keras.layers import Dense, Dropout, Layer, Activation\n",
    "from keras.datasets import mnist\n",
    "from keras import backend as K\n",
    "from keras.utils import np_utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## AntiRectifier Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Antirectifier(Layer):\n",
    "    '''This is the combination of a sample-wise\n",
    "    L2 normalization with the concatenation of the\n",
    "    positive part of the input with the negative part\n",
    "    of the input. The result is a tensor of samples that are\n",
    "    twice as large as the input samples.\n",
    "\n",
    "    It can be used in place of a ReLU.\n",
    "\n",
    "    # Input shape\n",
    "        2D tensor of shape (samples, n)\n",
    "\n",
    "    # Output shape\n",
    "        2D tensor of shape (samples, 2*n)\n",
    "\n",
    "    # Theoretical justification\n",
    "        When applying ReLU, assuming that the distribution\n",
    "        of the previous output is approximately centered around 0.,\n",
    "        you are discarding half of your input. This is inefficient.\n",
    "\n",
    "        Antirectifier allows to return all-positive outputs like ReLU,\n",
    "        without discarding any data.\n",
    "\n",
    "        Tests on MNIST show that Antirectifier allows to train networks\n",
    "        with twice less parameters yet with comparable\n",
    "        classification accuracy as an equivalent ReLU-based network.\n",
    "    '''\n",
    "\n",
    "    def compute_output_shape(self, input_shape):\n",
    "        shape = list(input_shape)\n",
    "        assert len(shape) == 2  # only valid for 2D tensors\n",
    "        shape[-1] *= 2\n",
    "        return tuple(shape)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        inputs -= K.mean(inputs, axis=1, keepdims=True)\n",
    "        inputs = K.l2_normalize(inputs, axis=1)\n",
    "        pos = K.relu(inputs)\n",
    "        neg = K.relu(-inputs)\n",
    "        return K.concatenate([pos, neg], axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Parametrs and Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# global parameters\n",
    "batch_size = 128\n",
    "nb_classes = 10\n",
    "nb_epoch = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data Preparation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "60000 train samples\n",
      "10000 test samples\n"
     ]
    }
   ],
   "source": [
    "# the data, shuffled and split between train and test sets\n",
    "(X_train, y_train), (X_test, y_test) = mnist.load_data()\n",
    "\n",
    "X_train = X_train.reshape(60000, 784)\n",
    "X_test = X_test.reshape(10000, 784)\n",
    "X_train = X_train.astype('float32')\n",
    "X_test = X_test.astype('float32')\n",
    "X_train /= 255\n",
    "X_test /= 255\n",
    "print(X_train.shape[0], 'train samples')\n",
    "print(X_test.shape[0], 'test samples')\n",
    "\n",
    "# convert class vectors to binary class matrices\n",
    "Y_train = np_utils.to_categorical(y_train, nb_classes)\n",
    "Y_test = np_utils.to_categorical(y_test, nb_classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model with Custom Layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 60000 samples, validate on 10000 samples\n",
      "Epoch 1/10\n",
      "60000/60000 [==============================] - 4s - loss: 0.6029 - acc: 0.9154 - val_loss: 0.1556 - val_acc: 0.9612\n",
      "Epoch 2/10\n",
      "60000/60000 [==============================] - 3s - loss: 0.1252 - acc: 0.9662 - val_loss: 0.0990 - val_acc: 0.9714\n",
      "Epoch 3/10\n",
      "60000/60000 [==============================] - 3s - loss: 0.0813 - acc: 0.9766 - val_loss: 0.0796 - val_acc: 0.9758\n",
      "Epoch 4/10\n",
      "60000/60000 [==============================] - 3s - loss: 0.0634 - acc: 0.9810 - val_loss: 0.0783 - val_acc: 0.9747\n",
      "Epoch 5/10\n",
      "60000/60000 [==============================] - 3s - loss: 0.0513 - acc: 0.9847 - val_loss: 0.0685 - val_acc: 0.9792\n",
      "Epoch 6/10\n",
      "60000/60000 [==============================] - 3s - loss: 0.0428 - acc: 0.9867 - val_loss: 0.0669 - val_acc: 0.9792\n",
      "Epoch 7/10\n",
      "60000/60000 [==============================] - 3s - loss: 0.0381 - acc: 0.9885 - val_loss: 0.0668 - val_acc: 0.9799\n",
      "Epoch 8/10\n",
      "60000/60000 [==============================] - 3s - loss: 0.0314 - acc: 0.9903 - val_loss: 0.0672 - val_acc: 0.9790\n",
      "Epoch 9/10\n",
      "60000/60000 [==============================] - 3s - loss: 0.0276 - acc: 0.9913 - val_loss: 0.0616 - val_acc: 0.9817\n",
      "Epoch 10/10\n",
      "60000/60000 [==============================] - 3s - loss: 0.0238 - acc: 0.9926 - val_loss: 0.0608 - val_acc: 0.9825\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7f2c140fbac8>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# build the model\n",
    "model = Sequential()\n",
    "model.add(Dense(256, input_shape=(784,)))\n",
    "model.add(Antirectifier())\n",
    "model.add(Dropout(0.1))\n",
    "model.add(Dense(256))\n",
    "model.add(Antirectifier())\n",
    "model.add(Dropout(0.1))\n",
    "model.add(Dense(10))\n",
    "model.add(Activation('softmax'))\n",
    "\n",
    "# compile the model\n",
    "model.compile(loss='categorical_crossentropy',\n",
    "              optimizer='rmsprop',\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "# train the model\n",
    "model.fit(X_train, Y_train,\n",
    "          batch_size=batch_size, epochs=nb_epoch,\n",
    "          verbose=1, validation_data=(X_test, Y_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Excercise"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare with an equivalent network that is **2x bigger** (in terms of Dense layers) + **ReLU**)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "## your code here"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
